import mindspore as ms
import mindspore.nn as nn
import numpy as np

from util.datasets import load_mnist
from util.transform import data_iter_2d
from util.callback import AccuracyMonitor
from mindvision.classification.dataset import Mnist


def main(data_root='/shareData/mindspore-dataset/Mnist', batch_size=256, lr=1.0,
         num_epochs=100):
    """Train a LeNet-style CNN on MNIST with MindSpore SGD.

    Validation accuracy is reported through ``AccuracyMonitor`` during training.

    Args:
        data_root: Directory holding the MNIST ``train`` and ``test``
            sub-directories passed to ``util.datasets.load_mnist``.
        batch_size: Mini-batch size for both the training and validation iterators.
        lr: SGD learning rate.
        num_epochs: Number of training epochs.
    """
    # Alternative loader via mindvision (kept for reference):
    # dataset_train = Mnist(path=data_root, split="train",
    #                       batch_size=batch_size, repeat_num=1, shuffle=True,
    #                       download=False, resize=28).run()
    # dataset_valid = Mnist(path=data_root, split="test",
    #                       batch_size=batch_size, repeat_num=1, shuffle=True,
    #                       download=False, resize=28).run()

    features_t, labels_t = load_mnist(f'{data_root}/train', reshape=True)
    features_v, labels_v = load_mnist(f'{data_root}/test', split='t10k', reshape=True)

    dataset_train = data_iter_2d(features_t, labels_t, batch_size)
    dataset_valid = data_iter_2d(features_v, labels_v, batch_size)

    # Expected shapes, assuming load_mnist(reshape=True) yields (N, 1, 28, 28)
    # images (TODO confirm; this is what Dense(400, ...) requires):
    #   conv 5x5 'same'  -> (N, 6, 28, 28) -> pool 2x2 -> (N, 6, 14, 14)
    #   conv 5x5 'valid' -> (N, 16, 10, 10) -> pool 2x2 -> (N, 16, 5, 5)
    #   flatten          -> (N, 400) -> (N, 84) -> (N, 10)
    net = nn.SequentialCell(
        nn.Conv2d(1, 6, 5, 1, pad_mode='same'),
        nn.Sigmoid(),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(6, 16, 5, 1, pad_mode='valid'),
        nn.Sigmoid(),
        nn.MaxPool2d(2, 2),
        nn.Flatten(),
        nn.Dense(400, 84),
        nn.ReLU(),
        # The network outputs raw logits on purpose. SoftmaxCrossEntropyWithLogits
        # applies softmax itself, so a trailing nn.Softmax() would apply it twice.
        # That would also floor the loss at ln(1 + 9/e) ~= 1.46. Argmax-based
        # accuracy is the same on logits and on probabilities.
        nn.Dense(84, 10),
    )

    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    opti = nn.SGD(net.trainable_params(), learning_rate=lr)

    model = ms.Model(net, loss_fn=loss, optimizer=opti, metrics={'acc'})

    model.train(num_epochs, dataset_train, callbacks=[AccuracyMonitor(dataset_valid)])


if __name__ == '__main__':
    # Train only when run as a script, not when imported.
    main()

"""
Sample output using the mindvision `Mnist` dataset loader (10 epochs):
epoch:[1/10] Loss:2.3013227 Train Accuracy:0.1123798076923077 Valid Accuracy:0.11348157051282051
epoch:[2/10] Loss:2.3012316 Train Accuracy:0.11224626068376069 Valid Accuracy:0.11348157051282051
epoch:[3/10] Loss:2.3011854 Train Accuracy:0.1123798076923077 Valid Accuracy:0.11348157051282051
epoch:[4/10] Loss:2.301252 Train Accuracy:0.11232972756410256 Valid Accuracy:0.11368189102564102
epoch:[5/10] Loss:2.301314 Train Accuracy:0.11239650106837606 Valid Accuracy:0.11368189102564102
epoch:[6/10] Loss:2.3009107 Train Accuracy:0.1123798076923077 Valid Accuracy:0.11348157051282051
epoch:[7/10] Loss:2.2953918 Train Accuracy:0.11234642094017094 Valid Accuracy:0.11358173076923077
epoch:[8/10] Loss:1.9633884 Train Accuracy:0.5126368856837606 Valid Accuracy:0.5154246794871795
epoch:[9/10] Loss:1.6618159 Train Accuracy:0.8090945512820513 Valid Accuracy:0.8121995192307693
epoch:[10/10] Loss:1.6204048 Train Accuracy:0.8458700587606838 Valid Accuracy:0.84765625

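Sample output using the self-implemented loader (util.datasets.load_mnist + util.transform.data_iter_2d), 100 epochs.
Note: the loss levels off near 1.47 and never approaches 0. For 10 classes, cross-entropy applied to inputs that are
already softmax probabilities (values in [0, 1]) cannot fall below ln(1 + 9/e) ~= 1.46. This plateau therefore suggests
the network ended in nn.Softmax() while SoftmaxCrossEntropyWithLogits applied softmax a second time.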
epoch:[1/100] Loss:2.301354 Train Accuracy:0.11236666666666667 Valid Accuracy:0.1135
epoch:[2/100] Loss:2.3013215 Train Accuracy:0.11236666666666667 Valid Accuracy:0.1135
epoch:[3/100] Loss:2.301274 Train Accuracy:0.11236666666666667 Valid Accuracy:0.1135
epoch:[4/100] Loss:2.301258 Train Accuracy:0.11236666666666667 Valid Accuracy:0.1135
epoch:[5/100] Loss:2.3012075 Train Accuracy:0.11236666666666667 Valid Accuracy:0.1135
epoch:[6/100] Loss:2.3013463 Train Accuracy:0.11236666666666667 Valid Accuracy:0.1135
epoch:[7/100] Loss:2.3012233 Train Accuracy:0.11236666666666667 Valid Accuracy:0.1135
epoch:[8/100] Loss:2.3012142 Train Accuracy:0.11236666666666667 Valid Accuracy:0.1135
epoch:[9/100] Loss:2.3011599 Train Accuracy:0.11236666666666667 Valid Accuracy:0.1135
epoch:[10/100] Loss:2.3011749 Train Accuracy:0.11236666666666667 Valid Accuracy:0.1135
epoch:[11/100] Loss:2.3009708 Train Accuracy:0.11236666666666667 Valid Accuracy:0.1135
epoch:[12/100] Loss:2.298999 Train Accuracy:0.11236666666666667 Valid Accuracy:0.1135
epoch:[13/100] Loss:2.2228186 Train Accuracy:0.20941666666666667 Valid Accuracy:0.2102
epoch:[14/100] Loss:2.055566 Train Accuracy:0.4046 Valid Accuracy:0.4015
epoch:[15/100] Loss:1.7609351 Train Accuracy:0.7164333333333334 Valid Accuracy:0.7259
epoch:[16/100] Loss:1.6771842 Train Accuracy:0.7960166666666667 Valid Accuracy:0.8056
epoch:[17/100] Loss:1.635459 Train Accuracy:0.8322333333333334 Valid Accuracy:0.8407
epoch:[18/100] Loss:1.6190662 Train Accuracy:0.8478833333333333 Valid Accuracy:0.8544
epoch:[19/100] Loss:1.6120905 Train Accuracy:0.8546166666666667 Valid Accuracy:0.8622
epoch:[20/100] Loss:1.6008794 Train Accuracy:0.8624333333333334 Valid Accuracy:0.868
epoch:[21/100] Loss:1.5960761 Train Accuracy:0.8679166666666667 Valid Accuracy:0.8712
epoch:[22/100] Loss:1.5989481 Train Accuracy:0.8654333333333334 Valid Accuracy:0.8695
epoch:[23/100] Loss:1.5616136 Train Accuracy:0.9092833333333333 Valid Accuracy:0.9138
epoch:[24/100] Loss:1.5198264 Train Accuracy:0.9487833333333333 Valid Accuracy:0.9538
epoch:[25/100] Loss:1.5260473 Train Accuracy:0.9419833333333333 Valid Accuracy:0.9422
epoch:[26/100] Loss:1.502354 Train Accuracy:0.9630666666666666 Valid Accuracy:0.9651
epoch:[27/100] Loss:1.4983165 Train Accuracy:0.9668666666666667 Valid Accuracy:0.9703
epoch:[28/100] Loss:1.4931138 Train Accuracy:0.97145 Valid Accuracy:0.9743
epoch:[29/100] Loss:1.4922625 Train Accuracy:0.9719 Valid Accuracy:0.9749
epoch:[30/100] Loss:1.5043103 Train Accuracy:0.9605833333333333 Valid Accuracy:0.9639
epoch:[31/100] Loss:1.4938688 Train Accuracy:0.9699166666666666 Valid Accuracy:0.9735
epoch:[32/100] Loss:1.4872283 Train Accuracy:0.9764166666666667 Valid Accuracy:0.9783
epoch:[33/100] Loss:1.4874744 Train Accuracy:0.9763 Valid Accuracy:0.9777
epoch:[34/100] Loss:1.4855503 Train Accuracy:0.9780666666666666 Valid Accuracy:0.9793
epoch:[35/100] Loss:1.5272132 Train Accuracy:0.93845 Valid Accuracy:0.9422
epoch:[36/100] Loss:1.4865075 Train Accuracy:0.97695 Valid Accuracy:0.9771
epoch:[37/100] Loss:1.4878961 Train Accuracy:0.9758 Valid Accuracy:0.976
epoch:[38/100] Loss:1.4829645 Train Accuracy:0.9804 Valid Accuracy:0.979
epoch:[39/100] Loss:1.4811144 Train Accuracy:0.98225 Valid Accuracy:0.9803
epoch:[40/100] Loss:1.4910432 Train Accuracy:0.97265 Valid Accuracy:0.9723
epoch:[41/100] Loss:1.4821796 Train Accuracy:0.981 Valid Accuracy:0.9793
epoch:[42/100] Loss:1.4814674 Train Accuracy:0.9816166666666667 Valid Accuracy:0.9795
epoch:[43/100] Loss:1.4863424 Train Accuracy:0.97675 Valid Accuracy:0.976
epoch:[44/100] Loss:1.478919 Train Accuracy:0.9839833333333333 Valid Accuracy:0.9808
epoch:[45/100] Loss:1.4789815 Train Accuracy:0.9839 Valid Accuracy:0.9806
epoch:[46/100] Loss:1.4780043 Train Accuracy:0.98505 Valid Accuracy:0.9815
epoch:[47/100] Loss:1.4795105 Train Accuracy:0.98345 Valid Accuracy:0.9804
epoch:[48/100] Loss:1.477847 Train Accuracy:0.9851 Valid Accuracy:0.9809
epoch:[49/100] Loss:1.4764723 Train Accuracy:0.9862833333333333 Valid Accuracy:0.9821
epoch:[50/100] Loss:1.4762201 Train Accuracy:0.9865666666666667 Valid Accuracy:0.9814
epoch:[51/100] Loss:1.4767867 Train Accuracy:0.9863 Valid Accuracy:0.9825
epoch:[52/100] Loss:1.4822782 Train Accuracy:0.9806 Valid Accuracy:0.9777
epoch:[53/100] Loss:1.4764909 Train Accuracy:0.9863 Valid Accuracy:0.9816
epoch:[54/100] Loss:1.4754539 Train Accuracy:0.9874166666666667 Valid Accuracy:0.9818
epoch:[55/100] Loss:1.4787201 Train Accuracy:0.9840833333333333 Valid Accuracy:0.9807
epoch:[56/100] Loss:1.4745486 Train Accuracy:0.9881666666666666 Valid Accuracy:0.9839
epoch:[57/100] Loss:1.4814782 Train Accuracy:0.98135 Valid Accuracy:0.979
epoch:[58/100] Loss:1.4753935 Train Accuracy:0.9875 Valid Accuracy:0.9834
epoch:[59/100] Loss:1.479766 Train Accuracy:0.9829833333333333 Valid Accuracy:0.9811
epoch:[60/100] Loss:1.4734492 Train Accuracy:0.9892666666666666 Valid Accuracy:0.984
epoch:[61/100] Loss:1.4732617 Train Accuracy:0.98935 Valid Accuracy:0.9844
epoch:[62/100] Loss:1.4747345 Train Accuracy:0.9879833333333333 Valid Accuracy:0.9838
epoch:[63/100] Loss:1.4728013 Train Accuracy:0.9899333333333333 Valid Accuracy:0.9848
epoch:[64/100] Loss:1.4775939 Train Accuracy:0.9852166666666666 Valid Accuracy:0.9805
epoch:[65/100] Loss:1.4731195 Train Accuracy:0.9893833333333333 Valid Accuracy:0.9851
epoch:[66/100] Loss:1.4720781 Train Accuracy:0.99045 Valid Accuracy:0.9847
epoch:[67/100] Loss:1.4722598 Train Accuracy:0.9903833333333333 Valid Accuracy:0.9845
epoch:[68/100] Loss:1.4728414 Train Accuracy:0.98965 Valid Accuracy:0.9842
epoch:[69/100] Loss:1.4723306 Train Accuracy:0.98995 Valid Accuracy:0.9856
epoch:[70/100] Loss:1.4727032 Train Accuracy:0.9899166666666667 Valid Accuracy:0.9837
epoch:[71/100] Loss:1.4721026 Train Accuracy:0.9903333333333333 Valid Accuracy:0.9851
epoch:[72/100] Loss:1.4716777 Train Accuracy:0.9907 Valid Accuracy:0.9861
epoch:[73/100] Loss:1.4723926 Train Accuracy:0.9901333333333333 Valid Accuracy:0.9843
epoch:[74/100] Loss:1.4711785 Train Accuracy:0.9913666666666666 Valid Accuracy:0.984
epoch:[75/100] Loss:1.471598 Train Accuracy:0.99085 Valid Accuracy:0.9857
epoch:[76/100] Loss:1.4707139 Train Accuracy:0.9918166666666667 Valid Accuracy:0.9858
epoch:[77/100] Loss:1.4720534 Train Accuracy:0.99065 Valid Accuracy:0.9851
epoch:[78/100] Loss:1.4705913 Train Accuracy:0.99175 Valid Accuracy:0.9855
epoch:[79/100] Loss:1.471405 Train Accuracy:0.9911333333333333 Valid Accuracy:0.9848
epoch:[80/100] Loss:1.4707181 Train Accuracy:0.9916333333333334 Valid Accuracy:0.9856
epoch:[81/100] Loss:1.4708439 Train Accuracy:0.99165 Valid Accuracy:0.9857
epoch:[82/100] Loss:1.4735742 Train Accuracy:0.98875 Valid Accuracy:0.9852
epoch:[83/100] Loss:1.4723717 Train Accuracy:0.9898833333333333 Valid Accuracy:0.9856
epoch:[84/100] Loss:1.4718761 Train Accuracy:0.9906666666666667 Valid Accuracy:0.9843
epoch:[85/100] Loss:1.4706904 Train Accuracy:0.9919 Valid Accuracy:0.9855
epoch:[86/100] Loss:1.4695128 Train Accuracy:0.9927333333333334 Valid Accuracy:0.9858
epoch:[87/100] Loss:1.4697739 Train Accuracy:0.9925 Valid Accuracy:0.9855
epoch:[88/100] Loss:1.470723 Train Accuracy:0.9917 Valid Accuracy:0.985
epoch:[89/100] Loss:1.4694605 Train Accuracy:0.9928166666666667 Valid Accuracy:0.9863
epoch:[90/100] Loss:1.4693828 Train Accuracy:0.9927833333333334 Valid Accuracy:0.9854
epoch:[91/100] Loss:1.4718018 Train Accuracy:0.9907166666666667 Valid Accuracy:0.9857
epoch:[92/100] Loss:1.4696373 Train Accuracy:0.9927 Valid Accuracy:0.9854
epoch:[93/100] Loss:1.4696767 Train Accuracy:0.99245 Valid Accuracy:0.9846
epoch:[94/100] Loss:1.4692967 Train Accuracy:0.9928 Valid Accuracy:0.9861
epoch:[95/100] Loss:1.4690297 Train Accuracy:0.9929833333333333 Valid Accuracy:0.9856
epoch:[96/100] Loss:1.4693991 Train Accuracy:0.9927666666666667 Valid Accuracy:0.9857
epoch:[97/100] Loss:1.4686409 Train Accuracy:0.99335 Valid Accuracy:0.9868
epoch:[98/100] Loss:1.4688703 Train Accuracy:0.9932333333333333 Valid Accuracy:0.9865
epoch:[99/100] Loss:1.4685272 Train Accuracy:0.9934333333333333 Valid Accuracy:0.9857
epoch:[100/100] Loss:1.4686333 Train Accuracy:0.99345 Valid Accuracy:0.9852
"""
